# Source code for nlp_architect.models.gnmt.utils.misc_utils

# ******************************************************************************
# Copyright 2017-2018 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ******************************************************************************
# Changes Made from original:
#   import paths
# ******************************************************************************
# Copyright 2017 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
# pylint: skip-file
"""Generally useful utility functions."""
from __future__ import print_function

import codecs
import collections
import collections.abc
import json
import math
import os
import sys
import time
from distutils import version

import tensorflow as tf


def check_tensorflow_version():
    """Raise EnvironmentError if the installed TensorFlow is too old.

    The minimum supported build is pinned below; comparison uses
    distutils' LooseVersion so dev/nightly suffixes compare sanely.
    """
    min_tf_version = "1.4.0-dev20171024"
    installed = version.LooseVersion(tf.__version__)
    required = version.LooseVersion(min_tf_version)
    if installed < required:
        raise EnvironmentError("Tensorflow version must >= %s" % min_tf_version)
def safe_exp(value):
    """Exponentiation with catching of overflow error.

    Returns math.exp(value), or positive infinity when the result would
    overflow a float.
    """
    try:
        return math.exp(value)
    except OverflowError:
        return float("inf")
def load_hparams(model_dir):
    """Load hparams from an existing model directory.

    Reads <model_dir>/hparams as UTF-8 JSON and builds an HParams object.
    Returns None when the file does not exist or cannot be parsed.
    """
    hparams_file = os.path.join(model_dir, "hparams")
    if not tf.gfile.Exists(hparams_file):
        return None
    print_out("# Loading hparams from %s" % hparams_file)
    reader = codecs.getreader("utf-8")
    with reader(tf.gfile.GFile(hparams_file, "rb")) as f:
        try:
            hparams_values = json.load(f)
            # Construct inside the try: a bad key set also raises ValueError.
            result = tf.contrib.training.HParams(**hparams_values)
        except ValueError:
            print_out(" can't load hparams file")
            return None
    return result
def maybe_parse_standard_hparams(hparams, hparams_path):
    """Override hparams values with existing standard hparams config.

    No-op (returns hparams unchanged) when hparams_path is falsy or the
    file does not exist.
    """
    if not hparams_path:
        return hparams
    if not tf.gfile.Exists(hparams_path):
        return hparams
    print_out("# Loading standard hparams from %s" % hparams_path)
    reader = codecs.getreader("utf-8")
    with reader(tf.gfile.GFile(hparams_path, "rb")) as f:
        hparams.parse_json(f.read())
    return hparams
def save_hparams(out_dir, hparams):
    """Save hparams.

    Serializes hparams as pretty-printed, key-sorted JSON into
    <out_dir>/hparams, encoded as UTF-8.
    """
    target = os.path.join(out_dir, "hparams")
    print_out("  saving hparams to %s" % target)
    writer = codecs.getwriter("utf-8")
    with writer(tf.gfile.GFile(target, "wb")) as f:
        f.write(hparams.to_json(indent=4, sort_keys=True))
def debug_tensor(s, msg=None, summarize=10):
    """Print the shape and value of a tensor at test time. Return a new tensor.

    When msg is not given, the tensor's own name is used as the label.
    """
    label = msg if msg else s.name
    return tf.Print(s, [tf.shape(s), s], label + " ", summarize=summarize)
def add_summary(summary_writer, global_step, tag, value):
    """Add a new summary to the current summary_writer.
    Useful to log things that are not part of the training graph, e.g., tag=BLEU.
    """
    scalar = tf.Summary.Value(tag=tag, simple_value=value)
    summary_writer.add_summary(tf.Summary(value=[scalar]), global_step)
def get_config_proto(log_device_placement=False, allow_soft_placement=True,
                     num_intra_threads=0, num_inter_threads=0):
    """Build a tf.ConfigProto for session creation.

    Enables GPU memory growth and, when the thread counts are non-zero,
    caps TensorFlow's intra-/inter-op parallelism. Zero leaves TF's
    defaults in place.
    """
    # GPU options:
    # https://www.tensorflow.org/versions/r0.10/how_tos/using_gpu/index.html
    proto = tf.ConfigProto(
        log_device_placement=log_device_placement,
        allow_soft_placement=allow_soft_placement)
    proto.gpu_options.allow_growth = True
    # CPU threads options
    if num_intra_threads:
        proto.intra_op_parallelism_threads = num_intra_threads
    if num_inter_threads:
        proto.inter_op_parallelism_threads = num_inter_threads
    return proto
def format_text(words):
    """Convert a sequence of byte-string words into a byte-string sentence.

    Args:
      words: an iterable of bytes tokens, or a bare scalar (e.g. a lone
        numpy value) which is wrapped in a single-element list first.

    Returns:
      The tokens joined with b" ".
    """
    # Bug fix: the old `collections.Iterable` alias was removed in
    # Python 3.10; the ABC lives in collections.abc.
    if (not hasattr(words, "__len__")  # for numpy array
            and not isinstance(words, collections.abc.Iterable)):
        words = [words]
    return b" ".join(words)
def format_bpe_text(symbols, delimiter=b"@@"):
    """Convert a sequence of bpe words into sentence.

    A trailing delimiter on a symbol marks a word continuation; the marker
    is stripped and the next symbol is glued on. Symbols without the
    marker close the current word.
    """
    # NOTE(review): if `symbols` is a str it is encoded to bytes, but
    # iterating bytes yields ints in Python 3 — presumably a Py2-era
    # path; callers appear to pass lists of bytes. TODO confirm.
    if isinstance(symbols, str):
        symbols = symbols.encode()
    words = []
    piece = b""
    marker_len = len(delimiter)
    for symbol in symbols:
        if len(symbol) >= marker_len and symbol[-marker_len:] == delimiter:
            # Continuation piece: strip the marker and keep accumulating.
            piece += symbol[:-marker_len]
        else:
            # Word boundary: flush the accumulated piece plus this symbol.
            words.append(piece + symbol)
            piece = b""
    return b" ".join(words)
def format_spm_text(symbols):
    """Decode a text in SPM (https://github.com/google/sentencepiece) format."""
    decoded = format_text(symbols).decode("utf-8")
    # SPM marks word starts with U+2581: drop the ordinary spaces inserted
    # by format_text, then turn each marker back into a real space.
    collapsed = u"".join(decoded.split())
    restored = collapsed.replace(u"\u2581", u" ")
    return restored.strip().encode("utf-8")